# -*- coding: utf-8 -*-
"""
Created on Tue Apr  5 12:14:22 2022

"""
import numpy as np 
import pandas as pd
#import networkx as nx
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import random
import pickle
import os
import shutil
from torch.autograd import Variable
import torch
import torch.nn as nn
import torch.nn.functional as F
from data_utils import PrepareConcDataset

# Global figure style for every plot produced by this module: serif Times New
# Roman text at 16 pt, the Unicode minus sign disabled on axes, and tick marks
# drawn pointing into the plot area.
plt.rcParams.update({
    "font.family": "serif",
    "font.serif": ["Times New Roman"],
    "axes.unicode_minus": False,
    "font.size": 16,
    "xtick.direction": "in",  # in; out; inout
    "ytick.direction": "in",
})

#%%
def write_pkl(write_data,pkl_path):
    """Serialise ``write_data`` to ``pkl_path`` with :mod:`pickle`.

    Any existing file at ``pkl_path`` is overwritten. The ``with`` block
    guarantees the data is flushed and the handle closed even if
    ``pickle.dump`` raises, instead of relying on garbage collection.
    """
    with open(pkl_path, 'wb') as f:
        pickle.dump(write_data, f)

def read_pkl(pkl_file_path):
    """Load and return the object stored in the pickle file ``pkl_file_path``.

    The file handle is closed deterministically via ``with``.

    SECURITY: ``pickle.load`` can execute arbitrary code while loading; only
    call this on files produced by this project, never on untrusted input.
    """
    with open(pkl_file_path, 'rb') as f:
        return pickle.load(f)
def make_dir(dir_path):
    """Create ``dir_path`` and any missing parent directories if absent.

    ``exist_ok=True`` avoids the race of a separate exists-check followed by
    ``makedirs``, where another process could create the directory in between.
    Raises FileExistsError if ``dir_path`` exists but is not a directory.
    """
    os.makedirs(dir_path, exist_ok=True)

#%%filter png
def filter_png(ele: str) -> bool:
    """Return True when ``ele`` does not contain the substring ``'png'``.

    Used as a ``filter`` predicate to drop image files from directory
    listings. This is a plain substring test, not a file-extension check.
    """
    return ele.find('png') == -1


def read_logs(log_path):
    """Collect model/args/events file paths from each run directory under ``log_path``.

    Every sub-directory of ``log_path`` (any depth, excluding ``log_path``
    itself) is visited in ``os.walk`` order. File names containing ``'png'``
    are ignored.

    A directory with fewer than three remaining files is treated as an
    incomplete run and is DELETED with ``shutil.rmtree``. For every other
    directory, the first three remaining file names are taken as the model,
    args and events files, in that order.

    NOTE(review): ``os.walk`` does not sort file names, so this
    model/args/events assignment depends on the filesystem's listing order;
    confirm it holds on the target platform.

    Args:
        log_path: root directory containing one sub-directory per run.

    Returns:
        dict with keys:
          * ``"log_dirs"``: every visited sub-directory, including deleted ones;
          * ``"model_dirs"`` / ``"args_dirs"`` / ``"events_dirs"``:
            ``dir + "/" + file`` paths for the kept directories;
          * ``"orders"``: for each kept directory, its index into ``"log_dirs"``.
        All lists are empty when ``log_path`` is missing or has no
        sub-directories.
    """
    # Materialise the whole walk before any rmtree below mutates the tree.
    # Entry [0] is log_path itself and is skipped.
    walk_entries = list(os.walk(log_path))[1:]

    log_dirs = []
    model_dirs = []
    args_dirs = []
    events_dirs = []
    orders = []

    for idx, (dir_path, _subdirs, file_names) in enumerate(walk_entries):
        log_dirs.append(dir_path)
        if not os.path.isdir(dir_path):
            # Already removed together with an incomplete parent run.
            continue
        # Same rule as filter_png: drop names containing 'png'.
        kept = [name for name in file_names if 'png' not in name]
        if len(kept) < 3:
            shutil.rmtree(dir_path)
            continue
        model_file, args_file, events_file = kept[:3]
        model_dirs.append(dir_path + "/" + model_file)
        args_dirs.append(dir_path + "/" + args_file)
        events_dirs.append(dir_path + "/" + events_file)
        orders.append(idx)

    return {
        "log_dirs": log_dirs,
        "model_dirs": model_dirs,
        "args_dirs": args_dirs,
        "events_dirs": events_dirs,
        "orders": orders,
    }



def plot_timeseries(x,y,name):
    """Draw one concentration time series and save it as ``{name}.png``.

    The image is written relative to the current working directory and the
    figure is closed afterwards.

    Args:
        x: time values (x axis).
        y: concentration values (y axis).
        name: series name used in the title and as the file stem.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(x, y, 'ob-')
    ax.set_title(f"{name} concentration time series")
    ax.set_xlabel('time/s')
    ax.set_ylabel('Concentration')
    fig.savefig(f'{name}.png', bbox_inches='tight')
    fig.clf()
    plt.close(fig)


def plot_loss(y1,y4,path,name):
    """Save a two-panel loss figure to ``{path}/{name}.png``.

    The left panel shows the training loss (``y1``) on a log y-scale. The
    right panel shows the l1 validation loss (``y4``) on a linear scale.
    Both panels are plotted against iteration index.
    """
    fig = plt.figure(figsize=(8, 4))
    # (subplot position, series, format, legend label, y label, y scale)
    panels = [
        (121, y1, 'b-', 'training loss', 'MSE Loss', 'log'),
        (122, y4, 'g-', 'l1 validation loss', 'l1 Loss', None),
    ]
    for position, series, fmt, legend_label, y_label, y_scale in panels:
        ax = fig.add_subplot(position)
        if y_scale is not None:
            ax.set_yscale(y_scale)
        ax.plot(series, fmt, label=legend_label)
        ax.set_xlabel('iteration')
        ax.set_ylabel(y_label)
        ax.legend()

    out_file = path + f'/{name}.png'
    fig.savefig(out_file, bbox_inches='tight')
    fig.clf()
    plt.close(fig)


#%%
def plot_head(args,model,train_input,train_label,train_label_mask,
                val_input,val_label,val_label_mask,heads_name_day,
                dataset_name,path,device,dev,mmin,edge_list):
    """Plot real vs. predicted head time series for the first 20 monitored nodes.

    The model is run in eval mode over the training and validation inputs in
    mini-batches of ``args['batch_size']``. Predictions are taken at the grid
    cells listed in ``edge_list[:, 1:]`` and de-normalised as
    ``output * dev + mmin``. Nodes are drawn nine per figure in a 3x3 grid,
    and each figure is saved as ``{path}/{dataset_name}_{figure_index}.png``.

    Args:
        args: dict; only ``'batch_size'`` is read.
        model: torch module whose output is indexed as B_T_C_H_W below.
        train_input, train_label, train_label_mask: training tensors (must be
            on CPU for ``.numpy()``); see the layout note below.
        val_input, val_label, val_label_mask: validation counterparts.
        heads_name_day: per-node names used in legends and titles.
        dataset_name: prefix of the saved file names.
        path: output directory.
        device: torch device used for inference.
        dev, mmin: scale and offset used to de-normalise model outputs.
        edge_list: array whose columns 1 and 2 give the grid index of each
            monitored node.

    NOTE(review): several assumptions look fragile; confirm before relying
    on this function.
      * The input/label slicing ``[:, :20, :, 0]`` followed by
        ``permute(1, 0, 2)`` only works on 4-D tensors with nodes on axis 1.
        This does not match the 5-D B_T_C_H_W layout of the model output.
      * ``i + j`` exceeds the number of nodes when that number is not a
        multiple of 9, which raises IndexError.
      * The inner ``k`` loop runs over the validation batch size but also
        indexes the training arrays with ``k``.
      * Labels/masks are moved to ``device`` inside the inference loops but
        never used there, and inference runs without ``torch.no_grad()``.
    """
    
    # Grid index of each monitored node (edge_list columns 1 and 2).
    index_node=np.array(edge_list[:,1:],dtype=np.int32)
    model.eval()
    total_train_outputs=[]
    for i in range(0,train_input.shape[0],args['batch_size']):
        inputs, labels,mask = train_input[i:i+args['batch_size']],train_label[i:i+args['batch_size']],train_label_mask[i:i+args['batch_size']]
        # labels/mask are transferred but not used below.
        inputs, labels,mask = Variable(inputs.to(device=device)), Variable(labels.to(device=device)),Variable(mask.to(device=device))
        outputs= model(inputs)
        #B_T_C_H_W->B_T_C_N->N_B_T_F->numpy
        # Keep the first 20 nodes, then de-normalise with dev/mmin.
        total_train_outputs.append((outputs[:,:,:,index_node[:,0],index_node[:,1]].permute(3,0,1,2)[:20,:,:,:].cpu().detach().numpy()*dev)+mmin)
    
    # numpy->N_B_T_F
    # Batches are joined along axis 1 (the batch axis after the permute).
    # Unlike the validation outputs below, the result is not squeezed.
    total_train_outputs=np.concatenate(total_train_outputs,axis=1)


    total_val_outputs=[]
    for i in range(0,val_input.shape[0],args['batch_size']):
        inputs, labels,mask = val_input[i:i+args['batch_size']],val_label[i:i+args['batch_size']],val_label_mask[i:i+args['batch_size']]
        inputs, labels,mask = Variable(inputs.to(device=device)), Variable(labels.to(device=device)),Variable(mask.to(device=device))
        outputs= model(inputs)
        #B_T_C_H_W->B_T_C_N->N_B_T_F->numpy
        total_val_outputs.append((outputs[:,:,:,index_node[:,0],index_node[:,1]].permute(3,0,1,2)[:20,:,:,:].cpu().detach().numpy()*dev)+mmin)
    
    # numpy->N_B_T_F->N_B_T
    total_val_outputs=np.concatenate(total_val_outputs,axis=1).squeeze()

    # The slices below take the first 20 entries of axis 1 and index 0 of
    # axis 3, then swap the first two axes (see the layout NOTE in the docstring).
    #B_T_C_H_W->B_T_C_N->B_T_N->N_B_T
    plot_train_t=train_input[:,:20,:,0].permute(1,0,2).numpy()
    
    #B_T_C_H_W->B_T_C_N->B_T_N->N_B_T
    plot_val_t=val_input[:,:20,:,0].permute(1,0,2).numpy()
    
    #B_T_C_H_W->B_T_C_N->B_T_N->N_B_T,
    plot_train_label=train_label[:,:20,:,0].permute(1,0,2).numpy()
    
    # Presumably a boolean mask of observed label points (used for boolean indexing below).
    plot_train_label_mask=train_label_mask[:,:20,:,0].permute(1,0,2).numpy()
    
    #B_T_C_H_W->B_T_C_N->B_T_N->N_B_T,
    plot_val_label=val_label[:,:20,:,0].permute(1,0,2).numpy()

    plot_val_label_mask=val_label_mask[:,:20,:,0].permute(1,0,2).numpy()
    
    # One 3x3 figure per group of nine nodes.
    for i in range(0,plot_train_label.shape[0],9):
        fig,axs = plt.subplots(3,3,figsize=(18,18),dpi=200,sharex=True, sharey=True)
        axs=axs.reshape(-1)
        
        for j in range(9):
            # Only the first sample (k == 0) adds legend labels, so each label
            # appears once per panel.
            for k in range(total_val_outputs.shape[1]):
                if k==0:
                    axs[j].plot(plot_train_t[i+j,k,:][plot_train_label_mask[i+j,k,:]],
                             plot_train_label[i+j,k,:][plot_train_label_mask[i+j,k,:]],
                             'g-',label=heads_name_day[i+j]+'_real',markersize=4)
                    
                    axs[j].plot(plot_train_t[i+j,k,:],total_train_outputs[i+j,k,:],
                             'b--',label=heads_name_day[i+j]+'_train')
                    
                    axs[j].plot(plot_val_t[i+j,k,:][plot_val_label_mask[i+j,k,:]],
                             plot_val_label[i+j,k,:][plot_val_label_mask[i+j,k,:]],
                             'g-',label=heads_name_day[i+j]+'_real',markersize=4)
                    
                    axs[j].plot(plot_val_t[i+j,k,:],total_val_outputs[i+j,k,:],
                             'r--',label=heads_name_day[i+j]+'_val')
                else:
                    axs[j].plot(plot_train_t[i+j,k,:][plot_train_label_mask[i+j,k,:]],
                             plot_train_label[i+j,k,:][plot_train_label_mask[i+j,k,:]],
                             'g-',markersize=4)
                    
                    axs[j].plot(plot_train_t[i+j,k,:],total_train_outputs[i+j,k,:],
                             'b--')
                    
                    axs[j].plot(plot_val_t[i+j,k,:][plot_val_label_mask[i+j,k,:]],
                             plot_val_label[i+j,k,:][plot_val_label_mask[i+j,k,:]],
                             'g-',markersize=4)
                    
                    axs[j].plot(plot_val_t[i+j,k,:],total_val_outputs[i+j,k,:],
                             'r--')
                
            # y label on the left column only, x label on the bottom row only.
            if j%3==0:
                axs[j].set_ylabel('head/m')
            if j>=6:
                axs[j].set_xlabel('time')
                
            axs[j].set_title(heads_name_day[i+j]+'_head time series')
            axs[j].legend(loc='upper right',framealpha=0.9)
            
        path_t=path+f'/{dataset_name}_{i//9}.png'
        plt.savefig(path_t, bbox_inches='tight',dpi=200)
        plt.clf()
        plt.close()


#%%
def plot_conc(args,model,path,device,dataset_name):
    """Evaluate the concentration model on train/test data; save plots and pickles.

    Loads the datasets with ``PrepareConcDataset(..., flag='plot')`` and runs
    ``model`` in eval mode over the train and test inputs, in mini-batches of
    ``args['batch_size']``. Predictions are extracted at the grid cells listed
    in ``edge_list[:, 1:]`` (first 20 nodes only). Samples are reordered by
    the de-normalised time of the first node.

    Two targets are evaluated per feature (``feature_names = ['U', 'H']``):
      * label2: per-node series, compared only where ``label2_mask`` is set;
      * label1: a summed series, predicted as
        ``sum over nodes of (node prediction * label1_coe)``.
    MAEs are de-normalised by multiplying with ``dev_list``.

    Side effects (file names built from ``path``):
      * pickles with per-node monitoring times and U concentrations for the
        extraction wells (train and test);
      * pickles with the summed series (time, observed, predicted) for train
        and test;
      * per-feature PNGs:
          - 3x3 grids of node time series;
          - 3x3 grids of predicted vs. real values;
          - the summed time series;
          - the summed predicted-vs-real plot.

    Args:
        args: dict; reads 'time_window', 'res', 'pred_day', 'buffer', 'gap',
            'test_day', 'val_ratio' and 'batch_size'.
        model: torch module whose output is indexed as B_T_C_H_W below.
        path: output directory prefix.
        device: torch device used for inference.
        dataset_name: prefix of the saved PNG file names.

    Returns:
        Mean over features of the de-normalised test MAE of the summed
        (label1) series.

    NOTE(review): fragile spots to confirm.
      * The ``.pkl`` paths are built as ``path + '<name>.pkl'`` without the
        ``'/'`` that every ``.png`` path uses. Unless ``path`` ends with a
        separator, the pickles land beside ``path``, not inside it.
      * The summed-series pickles are rewritten on every iteration of the
        feature loop. They therefore end up holding the last feature's values,
        although their names say uranium (铀).
      * ``reshape(20, -1)`` hard-codes 20 nodes and assumes every node has the
        same number of masked entries.
      * The 3x3 grids loop ``j`` over ``range(0, N - 2, 9)`` and index
        ``j + k`` for k < 9:
          - with N = 20, nodes 18 and 19 are never drawn;
          - with N < 9, this raises IndexError.
      * The summed-series plots sit inside the outer ``for i`` loop, after
        the inner ``for j`` loop. They reuse the last ``j`` as the row of
        ``plot_*_t`` that supplies the time axis.
      * A node whose mask is all False yields a NaN MAE.
      * Inference runs without ``torch.no_grad()``.
    """
    #%%
    train_input,train_label1,train_label1_coe,train_label2,train_label2_mask,\
    test_input,test_label1,test_label1_coe,test_label2,test_label2_mask,\
    mmin_list,dev_list,conc_name,edge_list,time_scaler=PrepareConcDataset(args['time_window'],args['res'],
                                                                          args['pred_day'],args['buffer'],
                                                                          args['gap'],args['test_day'],
                                                                          args['val_ratio'],flag='plot')
      
    # plot_conc_spatial(args,model,train_input,
    #                   conc_name,time_scaler,
    #                   path,device,dev_list,mmin_list,edge_list,dataset_name)    

    # Grid index of each monitored node (edge_list columns 1 and 2).
    index_node=np.array(edge_list[:,1:],dtype=np.int32)
    model.eval()
    total_train_outputs=[]
    for i in range(0,train_input.shape[0],args['batch_size']):
        inputs= train_input[i:i+args['batch_size']]
        inputs=Variable(inputs.to(device=device))
        outputs= model(inputs)
        #B_T_C_H_W->B_T_C_N->N_B_T_F->numpy
        # Outputs stay normalised here; they are de-normalised at plot time.
        total_train_outputs.append(outputs[:,:,:,index_node[:,0],index_node[:,1]].permute(3,0,1,2)[:20,:,:,:].cpu().detach().numpy())
    
    # numpy->N_B_T_F->N_B_F
    total_train_outputs=np.concatenate(total_train_outputs,axis=1).squeeze(2)


    total_test_outputs=[]
    for i in range(0,test_input.shape[0],args['batch_size']):
        inputs= test_input[i:i+args['batch_size']]
        inputs=Variable(inputs.to(device=device))
        outputs= model(inputs)
        #B_T_C_H_W->B_T_C_N->N_B_T_F->numpy
        total_test_outputs.append(outputs[:,:,:,index_node[:,0],index_node[:,1]].permute(3,0,1,2)[:20,:,:,:].cpu().detach().numpy())
    
    # numpy->N_B_T_F->N_B_F
    total_test_outputs=np.concatenate(total_test_outputs,axis=1).squeeze(2)


    # Time axis: input channel 0 after the first args['buffer'] steps, at each node.
    #B_T_C_H_W->B_T_C_N->N_B_T_C->N_B_T->N_B
    plot_train_t=train_input[:,args['buffer']:,:,:,:][:,:,:,index_node[:,0],index_node[:,1]].permute(3,0,1,2)[:20,:,:,0].numpy().squeeze(-1)
    #B_T_C_H_W->B_T_C_N->N_B_T_C->N_B_T->N_B
    plot_test_t=test_input[:,args['buffer']:,:,:,:][:,:,:,index_node[:,0],index_node[:,1]].permute(3,0,1,2)[:20,:,:,0].numpy().squeeze(-1)
    
    # Undo the time scaling, then sort samples by node 0's time; the same
    # order is applied to every node, label and prediction below.
    plot_train_t=plot_train_t*time_scaler.data_range_+time_scaler.data_min_
    plot_test_t=plot_test_t*time_scaler.data_range_+time_scaler.data_min_
    plot_train_order=np.argsort(plot_train_t[0,:])
    plot_test_order=np.argsort(plot_test_t[0,:])
    
    plot_train_t=plot_train_t[:,plot_train_order]
    plot_test_t=plot_test_t[:,plot_test_order]
    
    #B_T_C_H_W->B_T_C_N->N_B_T_F->B_T_F->B_F
    plot_train_label1=train_label1.squeeze(-1).permute(3,0,1,2).numpy().squeeze(0).squeeze(1)[plot_train_order,:]
    #B_T_C_H_W->B_T_C_N->N_B_T_F->N_B_F,
    plot_train_label1_coe=train_label1_coe[:,:,:,index_node[:,0],index_node[:,1]].permute(3,0,1,2)[:20,:,:,:].numpy().squeeze(2)[:,plot_train_order,:]
    #B_T_C_H_W->B_T_C_N->N_B_T_F->N_B_F
    plot_train_label2=train_label2[:,:,:,index_node[:,0],index_node[:,1]].permute(3,0,1,2)[:20,:,:,:].numpy().squeeze(2)[:,plot_train_order,:]
    #B_T_C_H_W->B_T_C_N->N_B_T_F->N_B_F
    plot_train_label2_mask=train_label2_mask[:,:,:,index_node[:,0],index_node[:,1]].permute(3,0,1,2)[:20,:,:,:].numpy().squeeze(2)[:,plot_train_order,:]
    #N_B_F->sort->N_B_F
    total_train_outputs=total_train_outputs[:,plot_train_order,:]
    
    # Predicted summed series: for each feature i, sum over nodes of
    # (prediction of feature i) * label1_coe.
    #N_B_F->B_F
    temp_train=np.concatenate([(total_train_outputs[:,:,i:i+1]*plot_train_label1_coe).sum(axis=0) 
                               for i in range(total_train_outputs.shape[-1])],axis=-1)
    
    # Per-feature MAE of the summed series, de-normalised by dev_list.
    #B_F->F
    train_label1_mae=abs(temp_train-plot_train_label1).mean(axis=0)*dev_list
    # Per-node, per-feature MAE over masked entries only (NaN if none are masked in).
    #N_B_F->N_F
    train_label2_mae=[]
    for i in range(total_train_outputs.shape[0]):
        temp_train2_mae=[]
        for j in range(total_train_outputs.shape[-1]):
            temp_train2_mae.append(abs(total_train_outputs[i,:,j]-plot_train_label2[i,:,j])[plot_train_label2_mask[i,:,j]].mean()*dev_list[j])
        train_label2_mae.append(temp_train2_mae)
    train_label2_mae=np.array(train_label2_mae)
    
    
    # Same processing for the test split.
    #B_T_C_H_W->B_T_C_N->N_B_T_F->B_T_F->B_F
    plot_test_label1=test_label1.squeeze(-1).permute(3,0,1,2).numpy().squeeze(0).squeeze(1)[plot_test_order,:]
    #B_T_C_H_W->B_T_C_N->N_B_T_F->N_B_F
    plot_test_label1_coe=test_label1_coe[:,:,:,index_node[:,0],index_node[:,1]].permute(3,0,1,2)[:20,:,:,:].numpy().squeeze(2)[:,plot_test_order,:]
    #B_T_C_H_W->B_T_C_N->N_B_T_F->N_B_F
    plot_test_label2=test_label2[:,:,:,index_node[:,0],index_node[:,1]].permute(3,0,1,2)[:20,:,:,:].numpy().squeeze(2)[:,plot_test_order,:]
    #B_T_C_H_W->B_T_C_N->N_B_T_F->N_B_F
    plot_test_label2_mask=test_label2_mask[:,:,:,index_node[:,0],index_node[:,1]].permute(3,0,1,2)[:20,:,:,:].numpy().squeeze(2)[:,plot_test_order,:]
    #N_B_F->sort->N_B_F
    total_test_outputs=total_test_outputs[:,plot_test_order,:]
    
    #N_B_F->B_F
    temp_test=np.concatenate([(total_test_outputs[:,:,i:i+1]*plot_test_label1_coe).sum(axis=0) 
                               for i in range(total_test_outputs.shape[-1])],axis=-1)
    
    #B_F->F
    test_label1_mae=abs(temp_test-plot_test_label1).mean(axis=0)*dev_list
    #N_B_F->N_F
    test_label2_mae=[]
    for i in range(total_test_outputs.shape[0]):
        temp_test2_mae=[]
        for j in range(total_test_outputs.shape[-1]):
            temp_test2_mae.append(abs(total_test_outputs[i,:,j]-plot_test_label2[i,:,j])[plot_test_label2_mask[i,:,j]].mean()*dev_list[j])
        test_label2_mae.append(temp_test2_mae)
    test_label2_mae=np.array(test_label2_mae)
    
    #plot_label1_mae(plot_train_label1,plot_test_label1,temp_train,temp_test,dev_list,dataset_name,path)

    #%%
    feature_names=['U','H']
    unit=['mg/l','g/l']
    #N_B_T_F
    
    # Masked (observed) times and de-normalised U values per node; reshape(20, -1)
    # assumes 20 nodes with equal numbers of observations.
    sig_train_t=plot_train_t[plot_train_label2_mask[...,0]].reshape(20,-1)
    sig_train_U=plot_train_label2[plot_train_label2_mask[...,0]].reshape(20,-1)*dev_list[0]+mmin_list[0]
    sig_test_t=plot_test_t[plot_test_label2_mask[...,0]].reshape(20,-1)
    sig_test_U=plot_test_label2[plot_test_label2_mask[...,0]].reshape(20,-1)*dev_list[0]+mmin_list[0]
    
    # File names (translated):
    #   extraction-well single-well, training stage, 5 monitorings: times / uranium concentration;
    #   extraction-well single-well, test stage, 1 monitoring: times / uranium concentration.
    # NOTE(review): no '/' between path and the file name (PNG paths below use one).
    write_pkl(sig_train_t, path+'抽液井单井训练阶段_5次监测时间.pkl')
    write_pkl(sig_train_U,path+'抽液井单井训练阶段_5次监测铀浓度.pkl')
    write_pkl(sig_test_t,path+'抽液井单井测试阶段_1次监测时间.pkl')
    write_pkl(sig_test_U,path+'抽液井单井测试阶段_1次监测铀浓度.pkl')
    # Per feature: 3x3 grids of node time series (observed points, train/test predictions).
    for i in range(plot_train_label1.shape[-1]):
        # NOTE(review): with 20 nodes this range yields j = 0, 9, so nodes 18-19 are skipped.
        for j in range(0,plot_train_label2.shape[0]-2,9):
            fig,axs = plt.subplots(3,3,figsize=(18,18),dpi=200,sharex=True, sharey=True)
            axs=axs.reshape(-1)
            for k in range(9):
                
                axs[k].plot(plot_train_t[j+k,:][plot_train_label2_mask[j+k,:,i]],
                         plot_train_label2[j+k,:,i][plot_train_label2_mask[j+k,:,i]]*dev_list[i]+mmin_list[i],
                         'go',label=conc_name[j+k]+'_'+feature_names[i]+'_real',markersize=4)
                
                axs[k].plot(plot_train_t[j+k,:],total_train_outputs[j+k,:,i]*dev_list[i]+mmin_list[i],
                         'b-',label=conc_name[j+k]+'_'+feature_names[i]+'_train')

                axs[k].plot(plot_test_t[j+k,:][plot_test_label2_mask[j+k,:,i]],
                         plot_test_label2[j+k,:,i][plot_test_label2_mask[j+k,:,i]]*dev_list[i]+mmin_list[i],
                         'go',markersize=4)
                
                axs[k].plot(plot_test_t[j+k,:],total_test_outputs[j+k,:,i]*dev_list[i]+mmin_list[i],
                         'r-',label=conc_name[j+k]+'_'+feature_names[i]+'_test')

                axs[k].text(0.01,0.95,f'train MAE={train_label2_mae[j+k,i]:.3f}', 
                            transform=axs[k].transAxes)
                axs[k].text(0.01,0.85,f'test MAE={test_label2_mae[j+k,i]:.3f}', 
                            transform=axs[k].transAxes)

                if k%3==0:
                    axs[k].set_ylabel(feature_names[i]+'_concentration '+unit[i])
                if k>=6:
                    axs[k].set_xlabel('time')
                    
                axs[k].set_title(conc_name[j+k]+'_'+feature_names[i]+' time series')
                axs[k].legend(loc='best',framealpha=0.9)
     
            path_t=path+f'/{dataset_name}_all_{feature_names[i]}_{j//9}.png'
            plt.savefig(path_t, bbox_inches='tight',dpi=200)
            plt.clf()
            plt.close()
    # Per feature: 3x3 grids of predicted (x) vs. observed (y) values with a 1:1 line.
    #%%         
    for i in range(plot_train_label1.shape[-1]):
        for j in range(0,plot_train_label2.shape[0]-2,9):
            fig,axs = plt.subplots(3,3,figsize=(18,18),dpi=200,sharex=True, sharey=True)
            axs=axs.reshape(-1)
            
            for k in range(9):
                    
                axs[k].plot(total_train_outputs[j+k,:,i][plot_train_label2_mask[j+k,:,i]]*dev_list[i]+mmin_list[i],
                         plot_train_label2[j+k,:,i][plot_train_label2_mask[j+k,:,i]]*dev_list[i]+mmin_list[i],
                         'bo',label=conc_name[j+k]+'_'+feature_names[i]+'_train',markersize=4)
                
                axs[k].plot(total_test_outputs[j+k,:,i][plot_test_label2_mask[j+k,:,i]]*dev_list[i]+mmin_list[i],
                         plot_test_label2[j+k,:,i][plot_test_label2_mask[j+k,:,i]]*dev_list[i]+mmin_list[i],
                         'go',label=conc_name[j+k]+'_'+feature_names[i]+'_test',markersize=4)


                axs[k].text(0.01,0.95,f'train MAE={train_label2_mae[j+k,i]:.3f}', 
                            transform=axs[k].transAxes)
                axs[k].text(0.01,0.85,f'test MAE={test_label2_mae[j+k,i]:.3f}', 
                            transform=axs[k].transAxes)
                
                # 1:1 reference line spanning the min/max of all training outputs
                # (over every node and feature), scaled with feature i.
                axs[k].plot(np.linspace((total_train_outputs.min()*dev_list[i]+mmin_list[i])*0.99,
                                        (total_train_outputs.max()*dev_list[i]+mmin_list[i])*1.01,100),
                            np.linspace((total_train_outputs.min()*dev_list[i]+mmin_list[i])*0.99,
                                        (total_train_outputs.max()*dev_list[i]+mmin_list[i])*1.01,100),
                            label='1:1',c='k',alpha=0.5)

                if k%3==0:
                    axs[k].set_ylabel('real')
                if k>=6:
                    axs[k].set_xlabel('predict')
                    
                axs[k].set_title(conc_name[j+k]+'_'+feature_names[i]+' time series')
                axs[k].legend(loc='lower right',framealpha=0.9)
     
            path_t=path+f'/{dataset_name}_all_deviation_{feature_names[i]}_{j//9}.png'
            plt.savefig(path_t, bbox_inches='tight',dpi=200)
            plt.clf()
            plt.close()

    #%%
        # Summed (label1) series for feature i. This runs inside the `for i`
        # loop above, after its `for j` loop, so `j` is that loop's last value.
        fig = plt.figure(figsize=(10,6),dpi=200)
        ax = fig.add_subplot(111)
        ax.plot(plot_train_t[j,:],
                 plot_train_label1[:,i]*dev_list[i]+mmin_list[i],
                 'go',label='sum'+'_'+feature_names[i]+'_real',markersize=4)
        # Pickles: training-stage total uranium concentration -- series time /
        # observed series / predicted series. Rewritten for every feature i.
        write_pkl(np.round(plot_train_t[j,:]).astype(int),path+'训练阶段总铀浓度序列时间.pkl')
        write_pkl(plot_train_label1[:,i]*dev_list[i]+mmin_list[i],path+'训练阶段总铀浓度监测序列.pkl')
        ax.plot(plot_train_t[j,:],temp_train[:,i]*dev_list[i]+mmin_list[i],
                 'b-',label='sum'+'_'+feature_names[i]+'_train')

        write_pkl(temp_train[:,i]*dev_list[i]+mmin_list[i],path+'训练阶段总铀浓度预测序列.pkl')
        ax.plot(plot_test_t[j,:],
                 plot_test_label1[:,i]*dev_list[i]+mmin_list[i],
                 'go',markersize=4)
        
        ax.plot(plot_test_t[j,:],temp_test[:,i]*dev_list[i]+mmin_list[i],
                 'r-',label='sum'+'_'+feature_names[i]+'_test')
        # Pickles: test-stage total uranium concentration -- series time /
        # observed series / predicted series.
        write_pkl(np.round(plot_test_t[j,:]).astype(int),path+'测试阶段总铀浓度序列时间.pkl')
        write_pkl(plot_test_label1[:,i]*dev_list[i]+mmin_list[i],path+'测试阶段总铀浓度监测序列.pkl')
        
        write_pkl(temp_test[:,i]*dev_list[i]+mmin_list[i],path+'测试阶段总铀浓度预测序列.pkl')
        
                                
        ax.text(0.01,0.95,f'train MAE={train_label1_mae[i]:.3f}', 
                transform=ax.transAxes)
        ax.text(0.01,0.85,f'test MAE={test_label1_mae[i]:.3f}', 
                transform=ax.transAxes)
        ax.set_ylabel(feature_names[i]+'_concentration '+unit[i])
        ax.set_xlabel('time')
        ax.set_title('sum'+'_'+feature_names[i]+' time series')
        # plt.legend(loc='upper right',framealpha=0.9)
        path_t=path+f'/{dataset_name}_sum_{feature_names[i]}.png'
        plt.savefig(path_t, bbox_inches='tight',dpi=200)
        plt.clf()
        plt.close()
#%%
        # Summed series for feature i: predicted (x) vs. observed (y) with a 1:1 line.
        fig = plt.figure(figsize=(5,5),dpi=200)
        ax = fig.add_subplot(111)
        ax.plot(temp_train[:,i]*dev_list[i]+mmin_list[i],
                 plot_train_label1[:,i]*dev_list[i]+mmin_list[i],
                 'bo',label='sum'+'_'+feature_names[i]+'_train',markersize=4)
        
        ax.plot(temp_test[:,i]*dev_list[i]+mmin_list[i],
                 plot_test_label1[:,i]*dev_list[i]+mmin_list[i],
                 'ro',label='sum'+'_'+feature_names[i]+'_test',markersize=4)

        ax.plot(np.linspace((plot_train_label1.min()*dev_list[i]+mmin_list[i])*0.99,
                            (plot_train_label1.max()*dev_list[i]+mmin_list[i])*1.01,100),
                np.linspace((plot_train_label1.min()*dev_list[i]+mmin_list[i])*0.99,
                            (plot_train_label1.max()*dev_list[i]+mmin_list[i])*1.01,100),
                label='1:1',c='k',alpha=0.5)
        ax.text(0.01,0.95,f'train MAE={train_label1_mae[i]:.3f}', 
                transform=ax.transAxes)
        ax.text(0.01,0.85,f'test MAE={test_label1_mae[i]:.3f}', 
                transform=ax.transAxes)
        ax.set_ylabel('real')
        ax.set_xlabel('predict')
        ax.set_title('sum'+'_'+feature_names[i]+' time series')
        plt.legend(loc='lower right',framealpha=0.9)
        path_t=path+f'/{dataset_name}_sum_deviation_{feature_names[i]}.png'
        plt.savefig(path_t, bbox_inches='tight',dpi=200)
        plt.clf()
        plt.close()
    #%%    
    # Mean over features of the de-normalised summed-series test MAE.
    return test_label1_mae.mean()



#%%
def plot_conc_spatial(args,model,train_input,
                      conc_name,time_scaler,
                      path,device,dev_list,mmin_list,edge_list,dataset_name):
    """Plot predicted spatial concentration fields at 18 evenly spaced training times.

    Steps:
      1. Run ``model`` (eval mode, no autograd) over ``train_input`` in
         mini-batches of ``args['batch_size']``.
      2. Flatten every (batch, time) output frame and de-normalise each
         channel with its own ``dev_list`` / ``mmin_list`` entry.
      3. Sort the frames by their physical time.
      4. For each channel, draw a 3x6 grid of ``pcolormesh`` maps with the
         monitored nodes overlaid as black dots.

    Args:
        args: dict; reads ``'batch_size'`` and ``'buffer'``.
        model: torch module whose output is indexed as B_T_C_H_W; assumes
            its time length equals that of ``train_input[:, args['buffer']:]``
            (TODO confirm).
        train_input: CPU torch tensor indexed as B_T_C_H_W; channel 0 holds
            the normalised time.
        conc_name: unused; kept for interface compatibility.
        time_scaler: object exposing ``data_range_`` and ``data_min_``
            (presumably a fitted sklearn ``MinMaxScaler``), used to map the
            normalised time back to physical units.
        path: output directory; figures are saved as
            ``{path}/{dataset_name}_spatial_train_{feature}.png``.
        device: torch device used for inference.
        dev_list, mmin_list: per-channel scale and offset for de-normalisation
            (channel ``i`` uses entry ``i``, as in ``plot_conc``).
        edge_list: array whose columns 1 and 2 hold the (row, column) grid
            index of each monitored node.
        dataset_name: prefix of the saved file names.
    """
    model.eval()
    index_node = np.array(edge_list[:, 1:], dtype=np.int32)
    batch_size = args['batch_size']

    # Inference only: skip building autograd graphs.
    batch_outputs = []
    with torch.no_grad():
        for start in range(0, train_input.shape[0], batch_size):
            inputs = train_input[start:start + batch_size].to(device=device)
            batch_outputs.append(model(inputs).cpu().numpy())

    # B_T_C_H_W -> (B*T)_C_H_W
    outputs = np.concatenate(batch_outputs, axis=0)
    outputs = outputs.reshape(outputs.shape[0] * outputs.shape[1], *outputs.shape[2:])
    n_channels = outputs.shape[1]

    # De-normalise every channel with its own scale/offset (not channel 0's for all).
    dev = np.asarray(dev_list, dtype=float).reshape(-1)[:n_channels].reshape(1, -1, 1, 1)
    mmin = np.asarray(mmin_list, dtype=float).reshape(-1)[:n_channels].reshape(1, -1, 1, 1)
    outputs = outputs * dev + mmin

    # B_T_C_H_W -> B_T_H_W (time channel) -> B_T (first node) -> flat
    frame_t = train_input[:, args['buffer']:, 0].numpy()[:, :, index_node[0, 0], index_node[0, 1]]
    frame_t = frame_t.reshape(-1)
    order = frame_t.argsort()
    frame_t = frame_t[order] * time_scaler.data_range_ + time_scaler.data_min_
    outputs = outputs[order]

    feature_names = ['U', 'H']
    n_panels = 18  # 3 x 6 grid
    selected = np.floor(np.linspace(0, outputs.shape[0] - 1, n_panels)).astype(np.int32)
    selected_t = frame_t[selected]

    for ch in range(n_channels):
        fig, axs = plt.subplots(3, 6, figsize=(36, 18), dpi=200, sharex=True, sharey=True)
        axs = axs.reshape(-1)
        for ax, frame, t in zip(axs, selected, selected_t):
            pcm = ax.pcolormesh(outputs[frame, ch, :, :], cmap='rainbow')
            ax.set_title(feature_names[ch] + f'_{t:.1f}day spatial')
            fig.colorbar(pcm, ax=ax)
            # pcolormesh draws array rows along y and columns along x, so the node
            # at (row, col) = (index_node[:, 0], index_node[:, 1]) sits at x=col, y=row.
            ax.scatter(index_node[:, 1] + 0.5, index_node[:, 0] + 0.5, c='k')

        path_t = path + f'/{dataset_name}_spatial_train_{feature_names[ch]}.png'
        fig.savefig(path_t, bbox_inches='tight', dpi=200)
        fig.clf()
        plt.close(fig)

#%%

def plot_label1_mae(plot_train_label1,plot_test_label1,temp_train,temp_test,
                    dev_list,dataset_name,path):
    """Plot cumulative, density-normalised histograms of the summed-series errors.

    The absolute errors ``|label - prediction| * dev_list`` are computed for
    the train and test splits. For each feature column, both histograms
    (7 bins) are drawn on one axes and saved to
    ``{path}/{dataset_name}_error_distribution_train_{feature}.png``.
    """
    train_err = abs(plot_train_label1 - temp_train) * dev_list
    test_err = abs(plot_test_label1 - temp_test) * dev_list

    names = ('U', 'H')
    shared_hist_kwargs = dict(bins=7, alpha=0.8, density=True, cumulative=True)
    for col in range(train_err.shape[-1]):
        feature = names[col]
        fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200)
        ax.hist(train_err[:, col], color='C1', label='train', **shared_hist_kwargs)
        ax.hist(test_err[:, col], color='C2', label='test', **shared_hist_kwargs)
        ax.legend(['train', 'test'], loc='upper left', bbox_to_anchor=(1.01, 1.02), framealpha=0.9)
        out_file = path + f'/{dataset_name}_error_distribution_train_{feature}.png'
        fig.savefig(out_file, bbox_inches='tight', dpi=200)
        fig.clf()
        plt.close(fig)





